"""
Phase 11: Pre-Trend Diagnosis
===============================
Three approaches to address contaminated pre-trends in GDP growth and TFP:

A) Control for recipient demographics (Z₁, Z₂, Z₃) in the LP.
   If pre-trend cleans up → confounding through demographics, not flows.
   If pre-trend persists → problem is deeper.

B) De-mean within country (country fixed effects via demeaning).
   If pre-trend cleans up → cross-sectional sorting was the issue.
   If persists → within-country time-series correlation.

C) Use growth residual: first regress outcome on Z, take residual,
   then run LP on the residual.
   Tests whether demographic inflows predict outcomes *beyond*
   what demographics themselves predict.

Also run these on Investment/GDP as a sanity check — should stay clean.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Project layout is resolved relative to this file: PROJECT_DIR is two
# levels up from the script, ROOT_DIR is its parent.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put <ROOT_DIR>/multilateral/src first on sys.path so that the `model`
# module (which provides PanelGLS) can be imported from there.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output table directory. The output directory
# is created at import time if it does not exist yet.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Horizons estimated by run_irf: PRE_HORIZONS (negative, pre-trend checks)
# followed by 0..MAX_HORIZON inclusive.
MAX_HORIZON = 5
PRE_HORIZONS = [-3, -2, -1]
# Control columns included in every specification; run_lp silently skips
# any that are absent from the data frame.
BASE_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

# Country codes matched against the `iso3` column in main() to select the
# OECD subsample.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# Test on these outcomes.
# Maps outcome column -> (display label, var_type). var_type is passed to
# build_horizon: 'level' outcomes are differenced against t-1, anything
# else is treated as a growth rate and cumulated over the horizon.
OUTCOMES = {
    'rgdp_growth': ('GDP Growth', 'growth'),              # contaminated
    'delta_log_tfp': ('TFP Growth', 'growth'),             # contaminated
    'gross_fixed_investment_gdp': ('Δ Investment/GDP', 'level'),  # clean (sanity)
}


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, else ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    # Thresholds are checked from strictest to loosest; first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def build_horizon(df, y_var, h, var_type):
    """Add the horizon-``h`` outcome column for ``y_var`` to a sorted copy of ``df``.

    The input frame is not modified; a copy sorted by (iso3, year) is
    returned together with the name of the new column. All shifts are
    positional within each ``iso3`` group, i.e. rows are assumed to be
    consecutive periods.

    var_type == 'level':
        column ``d_{y_var}_h{h}`` = y_{t+h} - y_{t-1} (for negative h as
        well, so h = -1 is zero by construction wherever y_{t-1} exists).
    any other var_type (growth rates):
        column ``{y_var}_h{h}`` =
          h == 0 : y_t
          h  > 0 : y_t + ... + y_{t+h}      (h + 1 terms, all required)
          h  < 0 : y_{t-1} + ... + y_{t-|h|} (|h| terms, all required)

    Returns
    -------
    (DataFrame, str)
        The sorted copy with the new column, and the column name.
    """
    panel = df.sort_values(['iso3', 'year']).copy()

    if var_type == 'level':
        col = f'd_{y_var}_h{h}'
        by_country = panel.groupby('iso3')[y_var]
        # shift(-h) reaches forward for h >= 0 and backward for h < 0.
        panel[col] = by_country.shift(-h) - by_country.shift(1)
        return panel, col

    col = f'{y_var}_h{h}'
    if h == 0:
        panel[col] = panel[y_var]
        return panel, col

    # Forward sums cover t..t+h; backward sums cover the |h| periods
    # strictly before t.
    window, offset = (h + 1, -h) if h > 0 else (-h, 1)

    def _cumulate(series):
        summed = series.rolling(window=window, min_periods=window).sum()
        return summed.shift(offset)

    panel[col] = panel.groupby('iso3')[y_var].transform(_cumulate)
    return panel, col


def run_lp(df, y_col, x_vars, controls, key_var):
    """Estimate one local projection and return the results for ``key_var``.

    Parameters
    ----------
    df : DataFrame
        Panel containing ``y_col``, the regressors and the ``iso3`` /
        ``year`` identifiers.
    y_col : str
        Outcome column.
    x_vars, controls : list of str
        Regressor names. Names absent from ``df`` are skipped; a name
        listed more than once enters the design matrix only once.
    key_var : str
        Regressor whose estimate is reported.

    Returns
    -------
    dict or None
        Keys 'coef', 'se', 'p', 'ci_lo', 'ci_hi' (coef ± 1.96·se),
        'n_obs', 'r_squared'. ``None`` when ``key_var`` is not among the
        usable regressors, when fewer than 30 complete rows remain, or
        when the fit raises.
    """
    # Keep first-occurrence order but drop repeats: a duplicated name would
    # put identical columns into the design matrix.
    regressors = [v for v in dict.fromkeys(x_vars + controls) if v in df.columns]

    # Nothing to report if the variable of interest cannot be estimated;
    # decide that before paying for a fit.
    if key_var not in regressors:
        return None

    wanted = dict.fromkeys([y_col] + regressors + ['iso3', 'year'])
    sub = df[[c for c in wanted if c in df.columns]].dropna()

    if len(sub) < 30:
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[y_col].values, sub[regressors].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as exc:
        # Best-effort by design: a failed horizon is reported as missing
        # rather than aborting the whole run, but leave a trace on stderr so
        # an 'n/a' row can be diagnosed.
        print(f"  [run_lp] fit failed for {y_col}: {exc!r}", file=sys.stderr)
        return None

    idx = regressors.index(key_var)
    coef = gls.beta[idx]
    se = gls.se[idx]

    return {
        'coef': coef,
        'se': se,
        'p': gls.pvalues[idx],
        'ci_lo': coef - 1.96 * se,
        'ci_hi': coef + 1.96 * se,
        'n_obs': gls.n_obs,
        'r_squared': gls.r_squared,
    }


def run_irf(df, y_var, var_type, x_vars, controls, key_var):
    """Estimate the impulse response of ``y_var`` at every horizon.

    Horizons are PRE_HORIZONS followed by 0..MAX_HORIZON. For each one the
    outcome column is built with build_horizon and passed to run_lp.

    Returns a dict mapping horizon -> run_lp result (``None`` where run_lp
    produced no estimate).
    """
    def _estimate(horizon):
        frame, outcome_col = build_horizon(df, y_var, horizon, var_type)
        return run_lp(frame, outcome_col, x_vars, controls, key_var)

    horizons = PRE_HORIZONS + list(range(MAX_HORIZON + 1))
    return {horizon: _estimate(horizon) for horizon in horizons}


def print_irf(irf, label):
    """Print an IRF as a compact table and tally the pre-trend tests.

    Parameters
    ----------
    irf : dict
        Maps horizon (int) to a result dict with keys 'coef', 'se', 'p'
        and 'n_obs', or to ``None`` when no estimate exists for that
        horizon. May be empty.
    label : str
        Heading printed above the table.

    Returns
    -------
    (int, int)
        ``(pre_sig, pre_tot)``: the number of negative-horizon estimates
        with p < 0.1, and the number of negative-horizon estimates
        available.
    """
    print(f"\n  {label}")
    print(f"  {'h':>3s}  {'β':>10s}  {'SE':>8s}  {'p':>8s}  {'N':>5s}")
    print(f"  {'─'*3}  {'─'*10}  {'─'*8}  {'─'*8}  {'─'*5}")

    pre_sig = 0
    pre_tot = 0
    for h in sorted(irf):
        pt = irf[h]
        if pt is None:
            print(f"  {h:3d}       n/a")
            continue
        s = stars(pt['p'])
        marker = ' ← PRE' if h < 0 else ''
        # int() keeps the 'd' format valid even if n_obs arrives as a float.
        print(f"  {h:3d}  {pt['coef']:10.5f}{s:3s}  {pt['se']:8.5f}  "
              f"{pt['p']:8.4f}  {int(pt['n_obs']):5d}{marker}")
        if h < 0:
            pre_tot += 1
            if pt['p'] < 0.1:
                pre_sig += 1

    # With no pre-period estimate at all there is nothing to call clean.
    if pre_tot == 0:
        status = 'n/a'
    elif pre_sig == 0:
        status = 'CLEAN'
    else:
        status = f'{pre_sig}/{pre_tot} sig'

    h2 = irf.get(2)
    h2_str = f"β={h2['coef']:.4f} (p={h2['p']:.4f})" if h2 else 'n/a'
    print(f"  Pre-trend: {status}  |  h=2: {h2_str}")
    return pre_sig, pre_tot


def approach_a_demographics(df, y_var, y_label, var_type):
    """Approach A: rerun the IRF with recipient Z₁, Z₂, Z₃ added as controls.

    ``y_label`` is accepted for a uniform signature across approaches but
    is not used. Returns the horizon -> result dict from run_irf.
    """
    shock = 'log_predicted_demo_inflows'
    demographic_controls = ['Z_1', 'Z_2', 'Z_3']
    return run_irf(
        df, y_var, var_type,
        [shock],
        BASE_CONTROLS + demographic_controls,
        shock,
    )


def approach_b_demean(df, y_var, y_label, var_type):
    """Approach B: subtract country means from all variables, then rerun the IRF.

    The shock, the base controls and the outcome are each de-meaned within
    ``iso3`` on a copy of ``df`` (columns that are absent are skipped).
    De-meaning happens before the horizon outcomes are built. ``y_label``
    is unused. Returns the horizon -> result dict from run_irf.
    """
    shock = 'log_predicted_demo_inflows'
    demeaned = df.copy()

    for name in [shock] + BASE_CONTROLS + [y_var]:
        if name not in demeaned.columns:
            continue
        within_mean = demeaned.groupby('iso3')[name].transform('mean')
        demeaned[name] = demeaned[name] - within_mean

    return run_irf(demeaned, y_var, var_type, [shock], BASE_CONTROLS, shock)


def approach_c_residual(df, y_var, y_label, var_type):
    """
    Approach C: strip the demographic component out of the outcome, then
    run the LP on what is left.

    Step 1: fit y on Z_1, Z_2, Z_3 over the rows where all of them (and
            iso3/year) are non-missing, and keep the residual.
    Step 2: run the IRF with the residual as the outcome, using the same
            ``var_type`` as the original outcome.

    Returns the horizon -> result dict from run_irf, or an empty dict when
    fewer than 50 complete rows are available or the first-stage fit raises.
    ``y_label`` is unused.
    """
    demo_cols = ['Z_1', 'Z_2', 'Z_3']

    # Step 1: residualize the outcome on the demographics that are present.
    wanted = [y_var] + demo_cols + ['iso3', 'year']
    sample = df[[name for name in wanted if name in df.columns]].dropna()
    if len(sample) < 50:
        return {}

    regressors = [z for z in demo_cols if z in sample.columns]
    first_stage = PanelGLS()
    try:
        first_stage.fit(sample[y_var].values, sample[regressors].values,
                        sample['iso3'].values, sample['year'].values)
    except Exception:
        return {}

    # NOTE(review): assumes first_stage.resid is in the same row order as
    # the arrays passed to fit() — confirm against PanelGLS.
    resid_var = f'{y_var}_resid'
    residuals = sample.assign(**{resid_var: first_stage.resid})

    # Attach the residual to the full panel; rows outside the first-stage
    # sample get NaN.
    df_r = df.merge(residuals[['iso3', 'year', resid_var]],
                    on=['iso3', 'year'], how='left')

    # Step 2: LP on the residual.
    shock = 'log_predicted_demo_inflows'
    return run_irf(df_r, resid_var, var_type, [shock], BASE_CONTROLS, shock)


# (spec key, short tag, description, message if it helps, message if not)
_APPROACH_NOTES = (
    ('A: + Demographics', 'A', '+ demographics',
     'Confounding through demographics was part of the problem.',
     'Problem is not just demographic confounding.'),
    ('B: De-meaned', 'B', 'de-mean',
     'Cross-sectional sorting was part of the problem.',
     'Within-country time-series correlation is the issue.'),
    ('C: Residualized', 'C', 'residual',
     'Demographics explain the pre-trending outcome variation.',
     None),
)


def _pretrend_counts(irf):
    """Return (n significant at 10%, n estimated) over PRE_HORIZONS."""
    estimated = [irf[h] for h in PRE_HORIZONS if irf.get(h)]
    significant = sum(1 for pt in estimated if pt['p'] < 0.1)
    return significant, len(estimated)


def _pretrend_status(pre_sig, pre_tot):
    """Format the pre-trend verdict; 'n/a' when nothing was estimated."""
    if pre_tot == 0:
        return 'n/a'
    return 'CLEAN' if pre_sig == 0 else f'{pre_sig}/{pre_tot} sig'


def _print_summary(all_results):
    """Print one line per outcome/spec with pre-trend status and the h=2 estimate."""
    print("\n" + "=" * 70)
    print("DIAGNOSIS SUMMARY")
    print("=" * 70)

    for y_var, (y_label, _) in OUTCOMES.items():
        specs = all_results.get(y_var, {})
        if not specs:
            continue

        print(f"\n  {y_label}:")
        for spec_label, irf in specs.items():
            status = _pretrend_status(*_pretrend_counts(irf))
            h2 = irf.get(2)
            h2_str = (f"β={h2['coef']:.4f}{stars(h2['p'])} (p={h2['p']:.4f})"
                      if h2 else 'n/a')
            print(f"    {spec_label:30s}  pre-trend: {status:12s}  h=2: {h2_str}")


def _print_interpretation(all_results):
    """Print, per outcome, whether each approach lowers the number of significant pre-trends."""
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)

    for y_var, (y_label, _) in OUTCOMES.items():
        specs = all_results.get(y_var, {})
        if not specs:
            continue

        base_pre, _ = _pretrend_counts(specs.get('Baseline', {}))

        print(f"\n  {y_label}:")
        if base_pre == 0:
            print("    No pre-trend issue in baseline. All approaches should be consistent.")
            continue

        for spec_key, tag, desc, helped, not_helped in _APPROACH_NOTES:
            pre_sig, pre_tot = _pretrend_counts(specs.get(spec_key, {}))
            if pre_tot == 0:
                # An approach that produced no estimates has zero significant
                # pre-trends only vacuously; do not report that as a reduction.
                print(f"    Approach {tag} produced no pre-period estimates; cannot assess.")
            elif pre_sig < base_pre:
                print(f"    Approach {tag} ({desc}) REDUCES pre-trend ({base_pre}→{pre_sig}).")
                print(f"    → {helped}")
            else:
                print(f"    Approach {tag} does NOT help ({base_pre}→{pre_sig}).")
                if not_helped:
                    print(f"    → {not_helped}")


def _save_results(all_results):
    """Write the long-format CSV and the markdown comparison table to OUT_TABLES."""
    print("\n--- Saving results ---")
    rows = [
        {'outcome': y_var, 'spec': spec_label, 'horizon': h, **pt}
        for y_var, specs in all_results.items()
        for spec_label, irf in specs.items()
        for h, pt in irf.items()
        if pt
    ]
    pd.DataFrame(rows).to_csv(OUT_TABLES / "pretrend_diagnosis_results.csv", index=False)

    # Explicit UTF-8: outcome labels contain non-ASCII characters ('Δ'),
    # which the platform default encoding may be unable to represent.
    with open(OUT_TABLES / "pretrend_diagnosis.md", 'w', encoding='utf-8') as f:
        f.write("# Table 11: Pre-Trend Diagnosis\n\n")
        f.write("Three approaches to address contaminated pre-trends.\n\n")

        for y_var, (y_label, _) in OUTCOMES.items():
            specs = all_results.get(y_var, {})
            if not specs:
                continue

            f.write(f"\n## {y_label}\n\n")
            headers = list(specs.keys())
            f.write(f"| h | {' | '.join(headers)} |\n")
            f.write(f"|---|{'|'.join(['---'] * len(headers))}|\n")

            for h in PRE_HORIZONS + list(range(MAX_HORIZON + 1)):
                cells = []
                for spec in headers:
                    pt = specs[spec].get(h)
                    cells.append(f"{pt['coef']:.4f}{stars(pt['p'])}" if pt else '')
                f.write(f"| {h} | {' | '.join(cells)} |\n")
            f.write("\n")

        f.write("*p<0.1, **p<0.05, ***p<0.01. OECD subsample.*\n")

    print(f"Saved: {OUT_TABLES / 'pretrend_diagnosis_results.csv'}")
    print(f"Saved: {OUT_TABLES / 'pretrend_diagnosis.md'}")


def main():
    """Run the pre-trend diagnosis on the OECD subsample and save the results.

    Reads ``deepening_panel.csv`` from DATA, and for every outcome in
    OUTCOMES present in the data estimates the baseline IRF plus approaches
    A, B and C, printing each table as it is produced. Then prints a summary
    and an interpretation, and writes ``pretrend_diagnosis_results.csv`` and
    ``pretrend_diagnosis.md`` to OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 11: PRE-TREND DIAGNOSIS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    df_oecd = df[df['iso3'].isin(OECD)].copy()
    print(f"OECD panel: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")

    shock = 'log_predicted_demo_inflows'
    all_results = {}

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df_oecd.columns:
            continue

        print(f"\n{'='*70}")
        print(f"OUTCOME: {y_label}")
        print(f"{'='*70}")

        # Baseline (no fix)
        irf_base = run_irf(df_oecd, y_var, var_type, [shock], BASE_CONTROLS, shock)
        print_irf(irf_base, 'BASELINE (no fix)')

        # Approach A: + recipient demographics
        irf_a = approach_a_demographics(df_oecd, y_var, y_label, var_type)
        print_irf(irf_a, 'APPROACH A: + recipient Z₁, Z₂, Z₃')

        # Approach B: de-mean within country
        irf_b = approach_b_demean(df_oecd, y_var, y_label, var_type)
        print_irf(irf_b, 'APPROACH B: de-meaned within country')

        # Approach C: residualized outcome
        irf_c = approach_c_residual(df_oecd, y_var, y_label, var_type)
        print_irf(irf_c, 'APPROACH C: outcome residualized on Z')

        all_results[y_var] = {
            'Baseline': irf_base,
            'A: + Demographics': irf_a,
            'B: De-meaned': irf_b,
            'C: Residualized': irf_c,
        }

    _print_summary(all_results)
    _print_interpretation(all_results)
    _save_results(all_results)
    print("\nPhase 11 complete.")


if __name__ == '__main__':
    main()
